from typing import Any

import dotenv
from openai import OpenAI

dotenv.load_dotenv()


class CustomConversationSummaryBufferMemory(object):

    def __init__(self, summary:str="", chat_histories:list=None, max_tokens:int=300):

        """
        初始化方法
        1.summary 摘要信息
        2.chat_histories 消息历史
        3.max_tokens 最大上下文长度，包括输入输出
        4._client 客户端对象
        """
        self.summary = summary
        self.chat_histories = chat_histories
        self.max_tokens = max_tokens
        self._client = OpenAI()

    @classmethod
    def get_num_tokens(cls, query:str)->int:
        """
        计算查询的token数
        """
        return len(query)

    def save_context(self,human_query:str,ai_content:str)->None:
        # 将用户查询和ai的回答信息追加到历史消息中
        self.chat_histories.append(
            {"human":human_query,"ai":ai_content}
        )
        # 获取当前历史记录缓冲的数据大小
        buffer_string = self.get_buffer_string()
        tokens = self.get_num_tokens(buffer_string)
        # 如果超过了最大上下文，进行摘要
        if tokens > self.max_tokens:
            # 选择第一条历史生成摘要
            first_chat = self.chat_histories[0]
            print("新摘要生成中~")
            self.summary = self.summary_text(
                self.summary,
                f"Human:{first_chat.get('human')}\nAI:{first_chat.get('ai')}"
            )
            print("新摘要生成成功:",self.summary)
            del  self.chat_histories[0]

    def get_buffer_string(self)->str:
        """
        用于计算当前缓冲区的字符串
        """
        buffer:str=''
        for chat in self.chat_histories:
            buffer += f"Human:{chat.get('human')}\nAI:{chat.get('ai')}\n\n"
        return buffer.strip()

    def load_memory_variables(self)->dict[str,Any]:
        """
        加载记忆变量为一个字典，便于格式化到prompt中
        """
        buffer_string = self.get_buffer_string()
        return {
            "chat_history":f"摘要:{self.summary}\n\n历史信息:{buffer_string}"
        }

    def summary_text(self,origin_summary:str,new_line:str)->str:
        """
        用于将旧的摘要和传入的新的会话生成一个新的摘要
        """
        prompt = f"""
        你是一个强大的聊天机器人，请根据用户提供的谈话内容，总结摘要，并将其添加到先前提供的摘要中，返回一个新的摘要，除了新摘要其他任何数据都不要生成，如果用户的对话信息里有一些关键的信息，比方说姓名、爱好、性别、重要事件等等，这些全部都要包括在生成的摘要中，摘要尽可能要还原用户的对话记录。
        
        请不要将<example>标签里的数据当成实际的数据，这里的数据只是一个示例数据，告诉你该如何生成新摘要。

        <example>
        当前摘要：人类会问人工智能对人工智能的看法，人工智能认为人工智能是一股向善的力量。
        
        新的对话：
        Human：为什么你认为人工智能是一股向善的力量？
        AI：因为人工智能会帮助人类充分发挥潜力。
        
        新摘要：人类会问人工智能对人工智能的看法，人工智能认为人工智能是一股向善的力量，因为它将帮助人类充分发挥潜力。
        </example>
        
        =====================以下的数据是实际需要处理的数据=====================
        当前摘要：{origin_summary}
        新的对话：
        {new_line}
        
        请帮用户将上面的信息生成新摘要。
        """
        completion = self._client.chat.completions.create(
            model= "gpt-4-turbo",
            messages=[
                {"role":"user","content":prompt},
            ]
        )
        return completion.choices[0].message.content


if __name__ == '__main__':
    client = OpenAI()
    memory = CustomConversationSummaryBufferMemory("",[],300)

    # 死循环测试记忆能力
    while True:
        query = input("Human: ")

        # 是否输入了终止符
        if query == 'q':
            break

        memory_variables = memory.load_memory_variables()

        answer_prompt = (
            "你是一个强大的聊天机器人，请根据对应的上下文和用户提问解决问题。\n\n"
            f"{memory_variables.get('chat_history')}\n\n"
            f"用户的提问是: {query}"
        )
        response = client.chat.completions.create(
            model='gpt-4-turbo',
            messages=[
                {"role":"user","content":answer_prompt},
            ],
            stream=True,
        )

        print("AI: ", flush=True, end="")

        ai_content = ""
        # 流式输出的情况下，遍历每一次输出作为一个块
        for chunk in response:
            # 这里获取每个块的delta的内容content
            content = chunk.choices[0].delta.content
            # 如果ai输出结束了，那么这里就停止循环
            if content is None:
                break
            # 拼接上输出的内容
            ai_content += content
            # 打印一下内容
            print(content,flush=True,end="")
        print("")
        # 这里进行历史记录的保存，并且根据设置的上下文长度，进行摘要
        memory.save_context(query,ai_content)